import torch
import matplotlib.pyplot as plt
import random
import copy
import torch.optim.lr_scheduler as lr_scheduler
from Trace import Covariance
from Visualization import funcaverage

def SGD(input, LossFunctions, eps, lr, decay_rate, bs, seed):
    """Run mini-batch SGD on a copy of ``input`` and record its trajectory.

    At every iteration ``bs`` distinct loss functions are drawn without
    replacement (``random.sample``) from ``LossFunctions`` and their mean is
    minimised with ``torch.optim.SGD``. The learning rate used at step ``t``
    is ``lr * decay_rate ** t`` (a ``LambdaLR`` scheduler stepped once per
    iteration).

    Args:
        input: Tensor holding the starting point. It must support autograd
            (leaf tensor with ``requires_grad=True``). It is deep-copied, so
            the caller's tensor is never modified. It is flattened when
            recorded into the trajectory.
        LossFunctions: Sequence of callables, each mapping the parameter
            tensor to a scalar loss tensor.
        eps: Number of SGD iterations.
        lr: Initial learning rate.
        decay_rate: Multiplicative learning-rate decay applied per iteration.
        bs: Mini-batch size (number of loss functions averaged per step).
        seed: Currently ignored. Seed ``random``/``torch`` in the caller if
            reproducibility is needed.

    Returns:
        Tuple ``(x, hit_ratio, Traj, H, AccCov)``:
          * ``x``: the optimised copy of ``input``.
          * ``hit_ratio``: fraction of iterations after which the first
            (flattened) coordinate of ``x`` was > 0. It is ``0.0`` when
            ``eps == 0``.
          * ``Traj``: ``(d, eps + 1)`` tensor with ``d = input.numel()``.
            Column ``t`` is the iterate before step ``t``, and the last column
            is the final iterate.
          * ``H`` / ``AccCov``: the same ``(d, d)`` zero tensor (returned
            twice), because the covariance accumulation is currently disabled.

    Raises:
        ValueError: if ``bs`` exceeds ``len(LossFunctions)`` (raised by
            ``random.sample``).
    """
    x = copy.deepcopy(input)
    # Size everything from the parameter instead of hard-coding 2 dimensions.
    dim = x.numel()
    AccCov = torch.zeros(dim, dim)
    optimizer = torch.optim.SGD([x], lr=lr)
    scheduler = lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda epoch: decay_rate ** epoch
    )
    indices = range(len(LossFunctions))
    PositiveHits = 0
    Traj = torch.zeros(dim, eps + 1)
    for ep in range(eps):
        Traj[:, ep] = x.detach().reshape(-1)
        optimizer.zero_grad()
        # Mean loss over a random mini-batch of the loss functions.
        batch = random.sample(indices, bs)
        loss = sum(LossFunctions[k](x) for k in batch) / bs
        loss.backward()
        optimizer.step()
        scheduler.step()
        # TODO: re-enable accumulation of
        #   Covariance(x, LossFunctions) * optimizer.param_groups[0]["lr"] ** 2
        # into AccCov once the covariance estimate is needed again.
        if x.detach().reshape(-1)[0] > 0:
            PositiveHits += 1
    Traj[:, eps] = x.detach().reshape(-1)
    # H aliases AccCov; both are returned to keep the caller-facing tuple.
    H = AccCov
    # Guard the ratio: zero iterations would otherwise divide by zero.
    hit_ratio = PositiveHits / eps if eps > 0 else 0.0
    return x, hit_ratio, Traj, H, AccCov